# simulation.py
"""
Contains the unified simulation loop function.
Now tracks rounds with no allocation.
"""

import numpy as np
import random
import math
from utils import get_report_from_action, discretize
import config
# Import agent model from agent.py
from agent import QLearningAgent
# Import planner functions from mechanisms.py
from mechanisms import planner_algorithm1_dual_sgd, planner_second_price_auction
from mechanisms import planner_optimistic_ftrl, planner_ftrl

def run_unified_simulation(planner_func, agent_type, mechanism_name, eta=0.0):
    """
    Runs a unified simulation loop.

    Each of the ``config.NUM_EPISODES`` episodes starts with a fresh budget
    (``config.T * config.RHO`` in every cost dimension) and a zero dual
    variable, then plays ``config.T`` rounds. In every round each agent draws
    a private value and a public consumption vector, submits a report (its
    Q-learning action, or its true value), and the planner picks at most one
    agent to allocate to.

    Args:
        planner_func (callable): The planner mechanism function to use. Only
            the four planners imported from ``mechanisms`` are dispatched;
            any other callable is never invoked and every round is counted
            as "no allocation".
        agent_type (str): 'q_learning' or 'truthful'. Any value other than
            'q_learning' is treated as truthful reporting.
        mechanism_name (str): Name of the mechanism for printing/plotting.
        eta (float): Step-size for primal-dual-based algorithms' update (if used).

    Returns:
        tuple: An 8-element tuple
            (total_welfare_history, agent_utility_history, total_cost_history,
             final_B_history, final_mu_history, no_allocation_counts_history,
             final_agents_or_none, total_mu_t_stats) where

            * total_welfare_history: per-episode sum of the allocated agents'
              private values.
            * agent_utility_history: one list per agent with the per-episode
              ``config.GAMMA``-discounted sum of its private value over the
              rounds it was allocated (payments are not subtracted).
            * total_cost_history: per-episode summed realized consumption.
            * final_B_history / final_mu_history: copies of the remaining
              budget and dual variable at the end of each episode.
            * no_allocation_counts_history: per-episode number of rounds in
              which the allocated index was -1.
            * final_agents_or_none: the Q-learning agents, or None for any
              other agent type.
            * total_mu_t_stats: per-episode mean of the ``mu_t_stats`` values
              returned by ``planner_optimistic_ftrl``; ``(0, 0)`` for
              episodes in which none were produced.
    """
    print(f"\nRunning Unified Simulation: {mechanism_name} with {agent_type} agents...")
    if planner_func in (planner_algorithm1_dual_sgd, planner_optimistic_ftrl, planner_ftrl):
        print(f"  Using eta = {eta:.4f} for dual updates.")

    # --- Initialization ---
    use_q_learning = agent_type == 'q_learning'
    agents = [QLearningAgent(i) for i in range(config.K)] if use_q_learning else None

    # Progress is reported roughly ten times per run. Clamp to 1 so that a
    # run with fewer than 10 episodes does not hit a modulo-by-zero.
    report_every = max(1, config.NUM_EPISODES // 10)

    # --- Training Phase ---
    total_welfare_history = []
    agent_utility_history = [[] for _ in range(config.K)]
    total_cost_history = []
    final_B_history = []
    final_mu_history = []
    no_allocation_counts_history = []
    total_mu_t_stats = []

    for episode in range(config.NUM_EPISODES):
        episode_total_welfare = 0
        episode_agent_utilities = np.zeros(config.K)
        episode_total_cost = np.zeros(config.COST_DIM)
        episode_no_allocation_count = 0
        episode_mu_t_stats = []

        # Budget and dual variable are reset at the start of every episode.
        current_B = config.T * config.RHO * np.ones(config.COST_DIM)
        current_mu = np.zeros(config.COST_DIM)

        agent_states_for_q_update = [None] * config.K
        agent_actions_for_q_update = [None] * config.K

        # Per-episode histories handed to the FTRL-style planners.
        report_u_history = []
        consumption_b_history = []
        allocation_i_history = []

        for t in range(config.T):
            # Draw this round's private values and public consumption vectors.
            private_values_v = []
            public_consumptions_b = []
            for i in range(config.K):
                private_values_v.append(np.random.uniform(
                    config.MIN_VALUE_RANGE[i],
                    config.MAX_VALUE_RANGE[i]
                ))
                public_consumptions_b.append(np.maximum(0, np.random.uniform(
                    np.asarray(config.MIN_ITEM_CONSUMPTION_FACTOR[i]) * config.RHO,
                    np.asarray(config.MAX_ITEM_CONSUMPTION_FACTOR[i]) * config.RHO
                )))

            # Collect reports: learned actions for Q-learners, true values otherwise.
            if use_q_learning:
                reports_u = []
                for i in range(config.K):
                    state = agents[i].get_state(private_values_v[i], t, current_mu)
                    action_idx = agents[i].choose_action(state)
                    report = get_report_from_action(
                        action_idx, agents[i].num_report_actions, agents[i].report_range)
                    agent_states_for_q_update[i] = state
                    agent_actions_for_q_update[i] = action_idx
                    reports_u.append(report)
            else:  # truthful agents
                reports_u = private_values_v

            # Defaults describe a "no allocation" round; they are kept as-is
            # when planner_func is not one of the planners handled below.
            allocated_agent_index = -1
            payments_vector = np.zeros(config.K)
            consumption_realized = np.zeros(config.COST_DIM)
            next_B = current_B
            next_mu = current_mu
            mu_t_stats = None

            if planner_func == planner_algorithm1_dual_sgd:
                planner_state_input = {'mu': current_mu, 'B': current_B}
                (allocated_agent_index, payments_vector, consumption_realized,
                 next_mu, next_B) = planner_func(
                    reports_u, public_consumptions_b, planner_state_input, eta)
            elif planner_func == planner_second_price_auction:
                # The auction has no dual variable, so next_mu stays unchanged.
                (allocated_agent_index, payments_vector, consumption_realized,
                 next_B) = planner_func(reports_u, public_consumptions_b, current_B)
            elif planner_func == planner_optimistic_ftrl or planner_func == planner_ftrl:
                # Both FTRL variants take the same state; the histories do not
                # yet include the current round.
                planner_state_input = {'t': t, 'mu': current_mu, 'B': current_B,
                                       'history_i': allocation_i_history,
                                       'history_u': report_u_history,
                                       'history_b': consumption_b_history}
                planner_output = planner_func(
                    reports_u, public_consumptions_b, planner_state_input, eta)
                if planner_func == planner_optimistic_ftrl:
                    # Only the optimistic variant returns per-round mu statistics.
                    (allocated_agent_index, payments_vector, consumption_realized,
                     next_mu, next_B, mu_t_stats) = planner_output
                else:
                    (allocated_agent_index, payments_vector, consumption_realized,
                     next_mu, next_B) = planner_output

            # Update history statistics
            report_u_history.append(reports_u)
            consumption_b_history.append(public_consumptions_b)
            allocation_i_history.append(allocated_agent_index)

            current_B = next_B
            current_mu = next_mu

            if allocated_agent_index == -1:
                episode_no_allocation_count += 1
            else:
                episode_total_welfare += private_values_v[allocated_agent_index]
            episode_total_cost += consumption_realized

            if mu_t_stats is not None:
                episode_mu_t_stats.append(mu_t_stats)

            discount = config.GAMMA ** t
            for i in range(config.K):
                allocated = 1 if allocated_agent_index == i else 0
                agent_reward_this_round = (private_values_v[i] - payments_vector[i]) * allocated
                # NOTE(review): the tracked utility is the discounted value of
                # the allocation and ignores the payment, unlike the Q-learning
                # reward above - confirm this is intended.
                episode_agent_utilities[i] += discount * (private_values_v[i] * allocated)

                if use_q_learning:
                    next_q_agent_state = None
                    if t < config.T - 1:
                        # NOTE(review): the simulated next value comes from
                        # config.VALUE_RANGE via `random`, whereas real values
                        # use the per-agent MIN/MAX_VALUE_RANGE via `np.random`
                        # - confirm this is intended.
                        next_v_simulated = random.uniform(*config.VALUE_RANGE)
                        next_q_agent_state = agents[i].get_state(
                            next_v_simulated, t + 1, current_mu)

                    if (agent_states_for_q_update[i] is not None
                            and agent_actions_for_q_update[i] is not None):
                        agents[i].update_q_table(
                            agent_states_for_q_update[i], agent_actions_for_q_update[i],
                            agent_reward_this_round, next_q_agent_state)

        total_welfare_history.append(episode_total_welfare)
        total_cost_history.append(episode_total_cost)
        final_B_history.append(current_B.copy())
        final_mu_history.append(current_mu.copy())
        no_allocation_counts_history.append(episode_no_allocation_count)
        total_mu_t_stats.append(
            np.mean(episode_mu_t_stats, axis=0) if episode_mu_t_stats else (0, 0))
        for i in range(config.K):
            agent_utility_history[i].append(episode_agent_utilities[i])

        if use_q_learning:
            for agent_obj in agents:
                agent_obj.decay_epsilon()

        if (episode + 1) % report_every == 0:
            # At this point at least `report_every` episodes have completed,
            # so the slice always covers a full reporting window.
            avg_welfare_last_chunk = np.mean(total_welfare_history[-report_every:])
            print(f"  Ep {episode+1}/{config.NUM_EPISODES} ({mechanism_name} / {agent_type}). AvgWelfare (last {report_every}): {avg_welfare_last_chunk:.2f}, No Alloc: {episode_no_allocation_count}")

    print(f"Simulation Complete for {mechanism_name} with {agent_type} agents.")
    final_agents_or_none = agents if use_q_learning else None

    return (total_welfare_history, agent_utility_history, total_cost_history,
            final_B_history, final_mu_history, no_allocation_counts_history,
            final_agents_or_none, total_mu_t_stats)
